{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 多GPU的简洁实现\n",
    ":label:`sec_multi_gpu_concise`\n",
    "\n",
    "每个新模型的并行计算都从零开始实现是无趣的。此外，优化同步工具以获得高性能也是有好处的。下面我们将展示如何使用深度学习框架的高级API来实现这一点。数学和算法与 :numref:`sec_multi_gpu`中的相同。不出所料，你至少需要两个GPU来运行本节的代码。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load ../utils/djl-imports\n",
    "%load ../utils/plot-utils\n",
    "%load ../utils/Training.java"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import ai.djl.basicdataset.cv.classification.*;\n",
    "import ai.djl.metric.*;\n",
    "import org.apache.commons.lang3.ArrayUtils;"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 简单网络\n",
    "\n",
    "让我们使用一个比 :numref:`sec_multi_gpu`的LeNet更有意义的网络，它依然能够容易地和快速地训练。我们选择的是 :cite:`He.Zhang.Ren.ea.2016`中的ResNet-18。因为输入的图像很小，所以稍微修改了一下。与 :numref:`sec_resnet`的区别在于，我们在开始时使用了更小的卷积核、步长和填充，而且删除了最大汇聚层。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Residual extends AbstractBlock {\n",
    "\n",
    "    private static final byte VERSION = 2;\n",
    "\n",
    "    public ParallelBlock block;\n",
    "\n",
    "    public Residual(int numChannels, boolean use1x1Conv, Shape strideShape) {\n",
    "        super(VERSION);\n",
    "\n",
    "        SequentialBlock b1;\n",
    "        SequentialBlock conv1x1;\n",
    "\n",
    "        b1 = new SequentialBlock();\n",
    "\n",
    "        b1.add(Conv2d.builder()\n",
    "                .setFilters(numChannels)\n",
    "                .setKernelShape(new Shape(3, 3))\n",
    "                .optPadding(new Shape(1, 1))\n",
    "                .optStride(strideShape)\n",
    "                .build())\n",
    "                .add(BatchNorm.builder().build())\n",
    "                .add(Activation::relu)\n",
    "                .add(Conv2d.builder()\n",
    "                        .setFilters(numChannels)\n",
    "                        .setKernelShape(new Shape(3, 3))\n",
    "                        .optPadding(new Shape(1, 1))\n",
    "                        .build())\n",
    "                .add(BatchNorm.builder().build());\n",
    "\n",
    "        if (use1x1Conv) {\n",
    "            conv1x1 = new SequentialBlock();\n",
    "            conv1x1.add(Conv2d.builder()\n",
    "                    .setFilters(numChannels)\n",
    "                    .setKernelShape(new Shape(1, 1))\n",
    "                    .optStride(strideShape)\n",
    "                    .build());\n",
    "        } else {\n",
    "            conv1x1 = new SequentialBlock();\n",
    "            conv1x1.add(Blocks.identityBlock());\n",
    "        }\n",
    "\n",
    "        block = addChildBlock(\"residualBlock\", new ParallelBlock(\n",
    "                list -> {\n",
    "                    NDList unit = list.get(0);\n",
    "                    NDList parallel = list.get(1);\n",
    "                    return new NDList(\n",
    "                            unit.singletonOrThrow()\n",
    "                                    .add(parallel.singletonOrThrow())\n",
    "                                    .getNDArrayInternal()\n",
    "                                    .relu());\n",
    "                },\n",
    "                Arrays.asList(b1, conv1x1)));\n",
    "    }\n",
    "\n",
    "    @Override\n",
    "    protected NDList forwardInternal(\n",
    "            ParameterStore parameterStore,\n",
    "            NDList inputs,\n",
    "            boolean training,\n",
    "            PairList<String, Object> params) {\n",
    "        return block.forward(parameterStore, inputs, training);\n",
    "    }\n",
    "\n",
    "    @Override\n",
    "    public Shape[] getOutputShapes(Shape[] inputs) {\n",
    "        Shape[] current = inputs;\n",
    "        for (Block block : block.getChildren().values()) {\n",
    "            current = block.getOutputShapes(current);\n",
    "        }\n",
    "        return current;\n",
    "    }\n",
    "\n",
    "    @Override\n",
    "    protected void initializeChildBlocks(NDManager manager, DataType dataType, Shape... inputShapes) {\n",
    "        block.initialize(manager, dataType, inputShapes);\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "public SequentialBlock resnetBlock(int numChannels, int numResiduals, boolean isFirstBlock) {\n",
    "\n",
    "        SequentialBlock blk = new SequentialBlock();\n",
    "        for (int i = 0; i < numResiduals; i++) {\n",
    "\n",
    "            if (i == 0 && !isFirstBlock) {\n",
    "                blk.add(new Residual(numChannels, true, new Shape(2, 2)));\n",
    "            } else {\n",
    "                blk.add(new Residual(numChannels, false, new Shape(1, 1)));\n",
    "            }\n",
    "        }\n",
    "        return blk;\n",
    "}\n",
    "\n",
    "int numClass = 10;\n",
    "// This model uses a smaller convolution kernel, stride, and padding and\n",
    "// removes the maximum pooling layer\n",
    "SequentialBlock net = new SequentialBlock();\n",
    "net\n",
    "    .add(\n",
    "            Conv2d.builder()\n",
    "                    .setFilters(64)\n",
    "                    .setKernelShape(new Shape(3, 3))\n",
    "                    .optPadding(new Shape(1, 1))\n",
    "                    .build())\n",
    "    .add(BatchNorm.builder().build())\n",
    "    .add(Activation::relu)\n",
    "    .add(resnetBlock(64, 2, true))\n",
    "    .add(resnetBlock(128, 2, false))\n",
    "    .add(resnetBlock(256, 2, false))\n",
    "    .add(resnetBlock(512, 2, false))\n",
    "    .add(Pool.globalAvgPool2dBlock())\n",
    "    .add(Linear.builder().setUnits(numClass).build());"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 网络初始化\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`setInitializer()`函数允许我们在所选设备上初始化参数。请参阅 :numref:`sec_numerical_stability`复习初始化方法。这个函数在多个设备上初始化网络时特别方便。让我们在实践中试一试它的运作方式。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Model model = Model.newInstance(\"training-multiple-gpus-1\");\n",
    "model.setBlock(net);\n",
    "\n",
    "Loss loss = Loss.softmaxCrossEntropyLoss();\n",
    "\n",
    "Tracker lrt = Tracker.fixed(0.1f);\n",
    "Optimizer sgd = Optimizer.sgd().setLearningRateTracker(lrt).build();\n",
    "\n",
    "DefaultTrainingConfig config = new DefaultTrainingConfig(loss)\n",
    "        .optOptimizer(sgd) // Optimizer (loss function)\n",
    "        .optInitializer(new NormalInitializer(0.01f), Parameter.Type.WEIGHT) // setting the initializer\n",
    "        .optDevices(Engine.getInstance().getDevices(1)) // setting the number of GPUs needed\n",
    "        .addEvaluator(new Accuracy()) // Model Accuracy\n",
    "        .addTrainingListeners(TrainingListener.Defaults.logging()); // Logging\n",
    "\n",
    "Trainer trainer = model.newTrainer(config);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "使用 :numref:`sec_multi_gpu`中引入的`split()`函数可以切分一个小批量数据，并将切分后的分块数据复制到多个设备设备中。网络实例自动使用适当的GPU来计算前向传播的值。我们将在下面生成$4$个观测值，并在GPU上将它们拆分。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "NDManager manager = NDManager.newBaseManager();\n",
    "NDArray X = manager.randomUniform(0f, 1.0f, new Shape(4, 1, 28, 28));\n",
    "trainer.initialize(X.getShape());\n",
    "\n",
    "NDList[] res = Batchifier.STACK.split(new NDList(X), 4, true);\n",
    "\n",
    "ParameterStore parameterStore = new ParameterStore(manager, true);\n",
    "\n",
    "System.out.println(net.forward(parameterStore, new NDList(res[0]), false).singletonOrThrow());\n",
    "System.out.println(net.forward(parameterStore, new NDList(res[1]), false).singletonOrThrow());\n",
    "System.out.println(net.forward(parameterStore, new NDList(res[2]), false).singletonOrThrow());\n",
    "System.out.println(net.forward(parameterStore, new NDList(res[3]), false).singletonOrThrow());"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "一旦数据通过网络，网络对应的参数就会在*有数据通过的设备上初始化*。这意味着初始化是基于每个设备进行的。由于我们选择的是GPU0和GPU1，所以网络只在这两个GPU上初始化，而不是在CPU上初始化。事实上，CPU上甚至没有这些参数。我们可以通过打印参数和观察可能出现的任何错误来验证这一点。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "net.getChildren().values().get(0).getParameters().get(\"weight\").getArray().get(new NDIndex(\"0:1\"));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 训练\n",
    "\n",
    "如前所述，用于训练的代码需要执行几个基本功能才能实现高效并行：\n",
    "\n",
    "* 需要在所有设备上初始化网络参数。\n",
    "* 在数据集上迭代时，要将小批量数据分配到所有设备上。\n",
    "* 跨设备并行计算损失及其梯度。\n",
    "* 聚合梯度，并相应地更新参数。\n",
    "\n",
    "最后，并行地计算精确度和发布网络的最终性能。除了需要拆分和聚合数据外，训练代码与前几章的实现非常相似。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "int numEpochs = Integer.getInteger(\"MAX_EPOCH\", 10);\n",
    "\n",
    "double[] testAccuracy;\n",
    "double[] epochCount;\n",
    "\n",
    "epochCount = new double[numEpochs];\n",
    "\n",
    "for (int i = 0; i < epochCount.length; i++) {\n",
    "    epochCount[i] = (i + 1);\n",
    "}\n",
    "\n",
    "Map<String, double[]> evaluatorMetrics = new HashMap<>();\n",
    "double avgTrainTimePerEpoch = 0;"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "public void train(int numEpochs, Trainer trainer, int batchSize) throws IOException, TranslateException {\n",
    "\n",
    "    FashionMnist trainIter = FashionMnist.builder()\n",
    "            .optUsage(Dataset.Usage.TRAIN)\n",
    "            .setSampling(batchSize, true)\n",
    "            .optLimit(Long.getLong(\"DATASET_LIMIT\", Long.MAX_VALUE))\n",
    "            .build();\n",
    "    FashionMnist testIter = FashionMnist.builder()\n",
    "            .optUsage(Dataset.Usage.TEST)\n",
    "            .setSampling(batchSize, true)\n",
    "            .optLimit(Long.getLong(\"DATASET_LIMIT\", Long.MAX_VALUE))\n",
    "            .build();\n",
    "\n",
    "    trainIter.prepare();\n",
    "    testIter.prepare();\n",
    "\n",
    "    Map<String, double[]> evaluatorMetrics = new HashMap<>();\n",
    "    double avgTrainTime = 0;\n",
    "\n",
    "    trainer.setMetrics(new Metrics());\n",
    "\n",
    "    EasyTrain.fit(trainer, numEpochs, trainIter, testIter);\n",
    "\n",
    "    Metrics metrics = trainer.getMetrics();\n",
    "\n",
    "    trainer.getEvaluators().stream()\n",
    "            .forEach(evaluator -> {\n",
    "                evaluatorMetrics.put(\"validate_epoch_\" + evaluator.getName(), metrics.getMetric(\"validate_epoch_\" + evaluator.getName()).stream()\n",
    "                        .mapToDouble(x -> x.getValue().doubleValue()).toArray());\n",
    "            });\n",
    "\n",
    "    avgTrainTime = metrics.mean(\"epoch\");\n",
    "    testAccuracy = evaluatorMetrics.get(\"validate_epoch_Accuracy\");\n",
    "    System.out.printf(\"test acc %.2f\\n\" , testAccuracy[numEpochs-1]);\n",
    "    System.out.println(avgTrainTime / Math.pow(10, 9) + \" sec/epoch \\n\");\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 实践\n",
    "\n",
    "让我们看看这在实践中是如何运作的。我们先在单个GPU上训练网络进行预热。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Table data = null;\n",
    "// We will check if we have at least 1 GPU available. If yes, we run the training on 1 GPU.\n",
    "if (Engine.getInstance().getGpuCount() >= 1) {\n",
    "    train(numEpochs, trainer, 256);\n",
    "\n",
    "    data = Table.create(\"Data\");\n",
    "    data = data.addColumns(\n",
    "            DoubleColumn.create(\"X\", epochCount), \n",
    "            DoubleColumn.create(\"testAccuracy\", testAccuracy)\n",
    "    );\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// 以下代码需要你有至少一个GPU设备\n",
    "// render(LinePlot.create(\"\", data, \"x\", \"testAccuracy\"), \"text/html\");    "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![Contour Gradient Descent.](https://d2l-java-resources.s3.amazonaws.com/img/training-with-1-gpu.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Table data = Table.create(\"Data\");\n",
    "\n",
    "// We will check if we have more than 1 GPU available. If yes, we run the training on 2 GPU.\n",
    "if (Engine.getInstance().getGpuCount() > 1) {\n",
    "\n",
    "    X = manager.randomUniform(0f, 1.0f, new Shape(1, 1, 28, 28));\n",
    "\n",
    "    Model model = Model.newInstance(\"training-multiple-gpus-2\");\n",
    "    model.setBlock(net);\n",
    "\n",
    "    loss = Loss.softmaxCrossEntropyLoss();\n",
    "\n",
    "    Tracker lrt = Tracker.fixed(0.2f);\n",
    "    Optimizer sgd = Optimizer.sgd().setLearningRateTracker(lrt).build();\n",
    "\n",
    "    DefaultTrainingConfig config = new DefaultTrainingConfig(loss)\n",
    "                .optOptimizer(sgd) // Optimizer (loss function)\n",
    "                .optInitializer(new NormalInitializer(0.01f), Parameter.Type.WEIGHT) // setting the initializer\n",
    "                .optDevices(Engine.getInstance().getDevices(2)) // setting the number of GPUs needed\n",
    "                .addEvaluator(new Accuracy()) // Model Accuracy\n",
    "                .addTrainingListeners(TrainingListener.Defaults.logging()); // Logging\n",
    "\n",
    "    Trainer trainer = model.newTrainer(config);\n",
    "    \n",
    "    trainer.initialize(X.getShape());\n",
    "\n",
    "    Map<String, double[]> evaluatorMetrics = new HashMap<>();\n",
    "    double avgTrainTimePerEpoch = 0;\n",
    "\n",
    "    train(numEpochs, trainer, 512);\n",
    "    \n",
    "    data = data.addColumns(\n",
    "        DoubleColumn.create(\"X\", epochCount), \n",
    "        DoubleColumn.create(\"testAccuracy\", testAccuracy)\n",
    "    );\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// 以下代码需要你有两个以上GPU设备\n",
    "// render(LinePlot.create(\"\", data, \"x\", \"testAccuracy\"), \"text/html\");"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![Contour Gradient Descent.](https://d2l-java-resources.s3.amazonaws.com/img/training-with-2-gpu.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 小结\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* Gluon通过提供一个上下文列表，为跨多个设备的模型初始化提供原语。\n",
    "* 神经网络可以在（可找到数据的）单GPU上进行自动评估。\n",
    "* 每台设备上的网络需要先初始化，然后再尝试访问该设备上的参数，否则会遇到错误。\n",
    "* 优化算法在多个GPU上自动聚合。\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 练习\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "1. 本节使用ResNet-18，请尝试不同的迭代周期数、批量大小和学习率，以及使用更多的GPU进行计算。如果使用$8$个GPU（例如，在AWS p2.16xlarge实例上）尝试此操作，会发生什么？\n",
    "1. 有时候不同的设备提供了不同的计算能力，我们可以同时使用GPU和CPU，那应该如何分配工作？为什么？\n",
    "\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Java",
   "language": "java",
   "name": "java"
  },
  "language_info": {
   "codemirror_mode": "java",
   "file_extension": ".jshell",
   "mimetype": "text/x-java-source",
   "name": "Java",
   "pygments_lexer": "java",
   "version": "14.0.2+12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
